import os, math, time, datetime, paddle, visualdl
import numpy as np
from matplotlib import pyplot as plt

from src.data import DataLoader
from src.model.ssyolo import SSYOLO

class Trainer():
    def __init__(self,
                 train_batch=2,
                 valid_batch=2,
                 epoch_num  =600,
                 eval_save  =5,
                 iter_show  =3,
                 lin_epoch  =10,
                 lin_value  =0.0,
                 lrate_max  =0.01,
                 cos_value  =0.0,
                 pie_epoch  =None,
                 pie_value  =None,
                 optimizer  ='AdamW',
                 l2_decays  =0.0005,
                 num_classes=4,
                 bbox_thresh=0.50,
                 load_flag  ='none',
                 train_txt  ='./data/train.txt',
                 valid_txt  ='./data/valid.txt',
                 label_txt  ='./data/label.txt',
                 save_path  ='./out/',
                 logs_path  ='./log/'):
        """
        初始化训练器
        params:
        - train_batch: 训练批次大小
        - valid_batch: 验证批次大小
        - epoch_num  : 训练轮次总数
        - eval_save  : 验证保存频率，每隔多少轮验证并保存一次
        - iter_show  : 迭代显示次数，每个轮次中迭代显示的次数
        - lin_epoch  : 线性预热轮数
        - lin_value  : 线性始学习率
        - lrate_max  : 最大学习率值
        - cos_value  : 余弦止学习率，当分段衰减为空时使用余弦
        - pie_epoch  : 分段衰减轮数
        - pie_value  : 分段学习率值
        - optimizer  : 权重值优化器，必须为AdamW或Momentum
        - l2_decays  : 权重衰减系数
        - num_classes: 物体类别数量
        - bbox_thresh: 验证边框阈值
        - load_flag  : 加载断点标识，none,resume,finetune
        - train_txt  : 训练数据路径
        - valid_txt  : 验证数据路径
        - label_txt  : 标签文件路径
        - save_path  : 模型保存路径
        - logs_path  : 日志保存路径
        """
        # 设置变量
        self.train_loader = DataLoader(train_txt, label_txt, train_batch, worker_num=0, mode='train') # 训练数据
        self.valid_loader = DataLoader(valid_txt, label_txt, valid_batch, worker_num=0, mode='valid') # 验证数据
        
        self.epoch_num = epoch_num                  # 训练轮次总数
        self.epoch_len = len(str(epoch_num))        # 轮次总数位数
        self.epoch_cur = 0                          # 当前训练轮次
        
        self.eval_save = eval_save                  # 验证保存频率
        self.iters_num = len(self.train_loader)     # 每轮迭代次数
        self.iters_all = self.iters_num * epoch_num # 训练迭代总数
        self.iter_show = self.iters_num//iter_show  # 迭代显示次数
        self.iters_len = len(str(self.iters_num))   # 迭代显示位数
        self.iters_cur = 0                          # 当前训练代数
        
        self.num_classes = num_classes              # 物体类别数量
        self.bbox_thresh = bbox_thresh              # 验证边框阈值
        self.save_path   = save_path                # 模型保存路径
        self.logs_path   = logs_path                # 日志保存路径
                
        self.train_loss_list = []                   # 训练损失列表
        self.valid_loss_list = []                   # 验证损失列表
        self.valid_mAP_list  = []                   # 验证精度列表
        self.best_epoch = 0                         # 验证最好轮次
        self.best_lr    = 0.0                       # 验证现学习率
        self.best_loss  = 1e6                       # 验证最好损失
        self.best_mAP   = 0.0                       # 验证最好精度
        
        # 声明模型
        self.model = SSYOLO(num_classes=num_classes)
        
        # 优化算法
        assert optimizer in ['AdamW', 'Momentum'], "错误：损失值优化器必须为'AdamW','Momentum'"
        if pie_epoch is None or pie_value is None: # 是否选择余弦学习率策略
            cos_iters = self.iters_num * (epoch_num - lin_epoch)                                      # 余弦衰减数
            lin_lrate = paddle.optimizer.lr.CosineAnnealingDecay(lrate_max, cos_iters, cos_value)     # 余弦学习率
        else:                                      # 否则选择分段学习率策略
            pie_iters = [self.iters_num * (i - lin_epoch) for i in pie_epoch]                         # 分段衰减数
            lin_lrate = paddle.optimizer.lr.PiecewiseDecay(pie_iters, pie_value)                      # 分段学习率
        lin_iters = self.iters_num * lin_epoch                                                        # 线性预热数
        self.scheduler = paddle.optimizer.lr.LinearWarmup(lin_lrate, lin_iters, lin_value, lrate_max) # 线性学习率
        if optimizer == 'AdamW':
            self.optimizer = paddle.optimizer.AdamW(                                                  # 适应优化器
                learning_rate=self.scheduler, weight_decay=l2_decays, parameters=self.model.parameters()
            )
        else:
            self.optimizer = paddle.optimizer.Momentum(                                               # 动量优化器
                learning_rate=self.scheduler, weight_decay=l2_decays, parameters=self.model.parameters()
            )
        
        # 加载断点
        assert load_flag in ['none', 'resume', 'finetune'], "错误：加载断点标识必须为'none','resume','finetune'"
        if load_flag in ['resume', 'finetune']: # 如果为恢复微调
            # 加载权重与参数
            self.epoch_cur = self.load_model(load_flag)
        else:                                   # 否则从初始训练
            # 创建模型文件夹
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            
            # 清空模型文件夹
            for root, dirs, files in os.walk(save_path, topdown=False):
                for name in files:
                    os.remove(os.path.join(root, name))
                for name in dirs:
                    os.rmdir(os.path.join(root, name))
            
            # 创建日志文件夹
            if not os.path.exists(logs_path):
                os.makedirs(logs_path)
            
            # 清空日志文件夹
            for root, dirs, files in os.walk(logs_path, topdown=False):
                for name in files:
                    os.remove(os.path.join(root, name))
                for name in dirs:
                    os.rmdir(os.path.join(root, name))
        
        # 日志记录
        self.trian_log = visualdl.LogWriter(os.path.join(logs_path, 'train/'), file_name='vdlrecords.train.log')
        self.valid_log = visualdl.LogWriter(os.path.join(logs_path, 'valid/'), file_name='vdlrecords.valid.log')
        
        # 训练超参
        self.valid_log.add_hparams(
            hparams_dict={'epoch': epoch_num, 'batch size': train_batch, 'learning rate': lrate_max, 'warmup': lin_epoch, 
                          'optimizer': optimizer, 'l2 decays': l2_decays},
            metrics_list=['Train/Loss', 'Train/mAP']
        )
        
        # 验证精度
        self.metric = DetectionMAP(num_classes, bbox_thresh)
        
    def train(self):
        """
        训练检测网络
        """
        # 训练网络
        if self.epoch_cur >= self.epoch_num:                                               # 如果当前轮次大于轮次总数
            print('ENDED - current epochs greater than or equals total number of epochs!') # 则打印结束信息并结束训练
            return
        for epoch_id in range(self.epoch_cur, self.epoch_num):    # 遍历轮次
            # 训练数据
            time_list = []          # 迭代时间列表
            iter_time = time.time() # 设置当前时间
            train_losses = []       # 训练损失列表
            self.model.train()      # 设置训练模式
            for batch_id, inputs in enumerate(self.train_loader): # 遍历数据
                # 读取数据
                images = inputs['image']                          # 图像数据
                
                # 前向传播
                p_list = self.model(images)                       # 提取特征

                # 计算损失
                losses = self.model.get_loss(p_list, inputs)      # 计算损失
                train_losses.append(losses.numpy()[0])            # 添加损失
                
                # 反向传播
                losses.backward()                                 # 反向传播
                self.optimizer.step()                             # 更新参数
                self.optimizer.clear_grad()                       # 梯度清零
                self.scheduler.step()                             # 更新学习
                
                # 显示训练
                self.iters_cur = epoch_id * self.iters_num + batch_id + 1        # 设置当前代数
                time_list.append(time.time() - iter_time)                        # 添加迭代时间
                iter_time = time.time()                                          # 设置当前时间
                if (batch_id + 1) % self.iter_show == 0:                         # 是否显示训练
                    epoch = epoch_id + 1                                         # 获取当前轮数
                    batch = batch_id + 1                                         # 获取训练批次
                    lr = self.optimizer.get_lr()                                 # 获当前学习率
                    
                    train_loss = np.mean(train_losses)                           # 计算平均损失
                    self.train_loss_list.append([self.iters_cur, train_loss])    # 添加训练损失
                    train_losses = []                                            # 清空损失列表
                    
                    eta_sec = (self.iters_all-self.iters_cur)*np.mean(time_list) # 计算剩余秒数
                    eta = 'ETA:' + str(datetime.timedelta(seconds=int(eta_sec))) # 设置剩余时间
                    time_list = []                                               # 清空时间列表
                    
                    self.save_result('TRAIN', epoch, batch, lr, train_loss, eta) # 保存训练结果
            
            # 验证数据
            if (epoch_id + 1) % self.eval_save == 0: # 是否验证保存
                # 验证数据
                valid_losses = []                    # 验证损失列表
                self.model.eval()                    # 设置验证模式
                for batch_id, inputs in enumerate(self.valid_loader): # 遍历数据
                    # 读取数据
                    images = inputs['image']                          # 图像数据
                    imghws = inputs['imghw']                          # 图像高宽

                    # 前向传播
                    p_list = self.model(images)                       # 提取特征

                    # 计算损失
                    losses = self.model.get_loss(p_list, inputs)      # 计算损失
                    valid_losses.append(losses.numpy()[0])            # 添加损失

                    # 计算预测
                    infers = self.model.get_pred(p_list, imghws)      # 计算预测
                    self.metric.updates(infers, inputs)               # 添加预测
                    
                # 显示验证
                epoch = epoch_id + 1                                         # 获取当前轮数
                batch = len(valid_losses)                                    # 获取验证批数
                lr = self.optimizer.get_lr()                                 # 获当前学习率

                valid_loss = np.mean(valid_losses)                           # 计算平均损失
                self.valid_loss_list.append([self.iters_cur, valid_loss])    # 添加验证损失
                valid_losses = []                                            # 清空损失列表

                valid_mAP = self.metric.mAP()                                # 计算平均精度
                self.valid_mAP_list.append([self.iters_cur, valid_mAP])      # 添加验证精度
                self.metric.reset()                                          # 重置验证精度

                mAP = f'mAP({self.bbox_thresh:.2f}):{valid_mAP:.4f}'         # 设置验证精度

                self.save_result('VALID', epoch, batch, lr, valid_loss, mAP) # 保存验证结果

                # 保存模型
                self.save_model(epoch, lr, valid_loss, valid_mAP)            # 保存模型参数
                
                # 保存日志
                self.trian_log.add_scalar(tag='Train/Loss', step=epoch, value=train_loss) # 添加训练损失
                self.valid_log.add_scalar(tag='Train/Loss', step=epoch, value=valid_loss) # 添加验证损失
                self.valid_log.add_scalar(tag='Train/mAP' , step=epoch, value=valid_mAP ) # 添加验证精度
        
        # 保存结果
        message = f'ENDED - best epoch: {self.best_epoch}, '+\
                  f'lr: {self.best_lr:.6f}, '+\
                  f'loss: {self.best_loss:.6f}, '+\
                  f'mAP({self.bbox_thresh:.2f}): {self.best_mAP:.4f}, '+\
                  f'logs path: {self.logs_path}'                      # 完成信息
        print(message)                                                # 打印信息
        with open(os.path.join(self.logs_path, 'log.txt'), 'a') as f: # 打开文件
            f.write(message + '\n')                                   # 保存结果
        self.save_figure()                                            # 保存图表
            
    def load_model(self, load_flag):
        """
        加载模型参数
        params:
        - load_flag: 加载断点标识
        return:
        - epoch_cur: 当前训练轮次
        """
        # 检测模型目录
        epoch_cur = 0                          # 设置当前轮次
        if not os.path.exists(self.save_path): # 若不存在目录
            os.makedirs(self.save_path)        # 创建保存目录
            print('警告：模型路径不存在！')        # 打印警告信息
            return epoch_cur
        
        # 加载模型权重
        if self.model is not None and os.path.exists(os.path.join(self.save_path, 'model.pdparams')): # 是否加载权重
            model_state_dict = paddle.load(os.path.join(self.save_path, 'model.pdparams'))            # 加载模型权重
            self.model.set_state_dict(model_state_dict)                                               # 设置模型权重
        else:
            print('警告：模型参数加载失败！')

        # 清空日志文件
        if load_flag == 'finetune': # 如果为微调训练，则清空日志文件
            for root, dirs, files in os.walk(self.logs_path, topdown=False):
                for name in files:
                    os.remove(os.path.join(root, name))
                for name in dirs:
                    os.rmdir(os.path.join(root, name))
        
        # 加载训练参数
        if load_flag == 'resume':   # 如果为恢复训练，加载优化器参数
            if self.optimizer is not None and os.path.exists(os.path.join(self.save_path, 'model.pdopt')): # 是否加载参数
                optimizer_state_dict = paddle.load(os.path.join(self.save_path, 'model.pdopt'))            # 加载训练参数
                if 'epoch_cur' in optimizer_state_dict:                                                    # 是否存在轮数
                    epoch_cur = optimizer_state_dict.pop('epoch_cur')                                      # 设置当前轮数
                self.optimizer.set_state_dict(optimizer_state_dict)                                        # 设置训练参数
            else:
                print('警告：优化参数加载失败！')
        
        return epoch_cur
    
    def save_model(self, epoch, lr, loss, mAP):
        """
        保存模型参数
        params:
        - epoch: 当前轮次
        - lr   : 现学习率
        - loss : 当前损失
        - mAP  : 当前精度
        """
        # 是否存在目录
        if not os.path.exists(self.save_path): # 若不存在目录
            os.makedirs(self.save_path)        # 创建保存目录
        
        # 保存模型权重
        model_state_dict = self.model.state_dict()                                        # 读取模型权重
        paddle.save(model_state_dict, os.path.join(self.save_path, 'model.pdparams'))     # 保存模型权重
        
        # 保存训练参数
        optimizer_state_dict = self.optimizer.state_dict()                                # 读取优化参数
        optimizer_state_dict['epoch_cur'] = epoch                                         # 设置训练轮次
        paddle.save(optimizer_state_dict, os.path.join(self.save_path, 'model.pdopt'))    # 保存优化参数
        
        # 保存最好模型
        if mAP >= self.best_mAP:                                                          # 是否最好精度
            self.best_epoch = epoch                                                       # 更新最好轮次
            self.best_lr    = lr                                                          # 更新现学习率
            self.best_loss  = loss                                                        # 更新最好损失
            self.best_mAP   = mAP if mAP < 1.0 else self.best_mAP                         # 更新最好精度
            paddle.save(model_state_dict, os.path.join(self.save_path, 'great.pdparams')) # 保存模型权重
    
    def save_result(self, mode, epoch, batch, lr, loss, info):
        """
        保存结果信息
        params:
        - mode : 模式名称
        - epoch: 轮次数值
        - batch: 批次数值
        - lr   : 学习率值
        - loss : 损失数值
        - info : 模式信息
        """
        assert mode in ['TRAIN', 'VALID'], '错误：mode必须为"train"或"valid"'
        message = f'{mode} - {time.strftime("%m/%d %H:%M:%S", time.localtime())}, ' +\
                  f'epoch:{str(epoch).rjust(self.epoch_len)}/{self.epoch_num}, ' +\
                  f'batch:{str(batch).rjust(self.iters_len)}, ' +\
                  f'lr:{lr:.6f}, ' +\
                  f'loss:{loss:12.6f}, ' +\
                  f'{info}'                                           # 保存信息
        print(message)                                                # 打印信息
        with open(os.path.join(self.logs_path, 'log.txt'), 'a') as f: # 打开文件
            f.write(message + '\n')                                   # 写入信息
        
    def save_figure(self):
        """
        保存结果图表
        """
        # 创建绘制画布
        plt.figure(figsize=(18, 6))
        
        # 绘制损失子图
        ax1 = plt.subplot(1, 2, 1)                                # 获取损失子图
        
        x1, y1 = np.array(self.train_loss_list).transpose((1, 0)) # 训练损失坐标
        ax1.plot(x1, y1, color='b', label='train loss')           # 训练损失子图
        
        x2, y2 = np.array(self.valid_loss_list).transpose((1, 0)) # 验证损失坐标
        ax1.plot(x2, y2, color='g', label='valid loss')           # 验证损失子图
        
        ax1.set_xlabel("Iter")                                    # 设置横轴标签
        ax1.set_ylabel("Loss")                                    # 设置纵轴标签
        ax1.set_xlim(min(x2), max(x2))                            # 设置横轴范围
        ax1.set_ylim(0, max(y2) + 1)                              # 设置纵轴范围
        ax1.legend()                                              # 显示图表标注
        
        # 绘制精度子图
        ax2 = plt.subplot(1, 2, 2)                                # 获取精度子图

        x3, y3 = np.array(self.valid_mAP_list).transpose((1, 0))  # 验证精度坐标
        ax2.plot(x3, y3, color='r', label='valid mAP')            # 验证精度子图

        ax2.set_xlabel("Iter")                                    # 设置横轴标签
        ax2.set_ylabel("mAP ")                                    # 设置纵轴标签
        ax2.set_xlim(min(x3), max(x3))                            # 设置横轴范围
        ax2.set_ylim(0, 1)                                        # 设置纵轴范围
        ax2.legend()                                              # 显示图表标注
        
        # 保存绘制图表
        plt.savefig(os.path.join(self.logs_path, 'log.png'), bbox_inches='tight') # 保存图表
        plt.tight_layout()                                                        # 紧凑布局
        plt.show()                                                                # 显示图表

####################################################################################

class DetectionMAP(object):
    def __init__(self, num_classes=4, bbox_thresh=0.5):
        """
        初始计算精度方法
        params:
        - num_classes: 物体类别数量
        - bbox_thresh: 验证边框阈值
        """
        self.num_classes = num_classes # 预测类别数量
        self.bbox_thresh = bbox_thresh # 验证边框阈值
        self.reset()                   # 重置计算精度
        
    def reset(self):
        """
        重置计算精度方法
        params:
        """
        self.count = [0] * self.num_classes                 # 数量统计列表
        self.score = [ [] for _ in range(self.num_classes)] # 得分统计列表
        
    def updates(self, infers, inputs):
        """
        更新批次预测统计
        params:
        - infers: 预测结果
        - inputs: 输入数据
        """
        # 读取输入数据
        gtboxes = inputs['gtbox'].numpy() # [b,50,4]
        gtclses = inputs['gtcls'].numpy() # [b,50,1]
        imghwes = inputs['imghw'].numpy() # [b,2]
        
        # 更新预测统计
        for i, infer in enumerate(infers): # 遍历批次
            # 是否预测为空
            if len(infer) < 1: # 如果本批预测结果为空，那么进行下批预测统计
                continue
            
            # 计算边框类别
            gtbox, gtcls = self.get_gtbox_gtcls(gtboxes[i], gtclses[i], imghwes[i])
            
            # 更新预测统计
            self.update(infer, gtbox, gtcls)
            
    def get_gtbox_gtcls(self, gtbox, gtcls, imghw):
        """
        计算物体边框和类别
        params:
        - gtbox: 物体边框
        - gtcls: 物体类别
        - imghw: 物体高宽
        return:
        - gtbox: 物体边框
        - gtcls: 物体类别
        """
        # 调整坐标格式
        gtbox[:, 0] = gtbox[:, 0] - gtbox[:, 2] / 2.0 # 物体边框x1
        gtbox[:, 1] = gtbox[:, 1] - gtbox[:, 3] / 2.0 # 物体边框y1
        gtbox[:, 2] = gtbox[:, 0] + gtbox[:, 2]       # 物体边框x2
        gtbox[:, 3] = gtbox[:, 1] + gtbox[:, 3]       # 物体边框y2

        # 计算原图坐标
        gtbox[:, 0] = gtbox[:, 0] * imghw[1] # 物体边框x1
        gtbox[:, 1] = gtbox[:, 1] * imghw[0] # 物体边框y1
        gtbox[:, 2] = gtbox[:, 2] * imghw[1] # 物体边框x2
        gtbox[:, 3] = gtbox[:, 3] * imghw[0] # 物体边框y2
        
        # 统计有效边框
        count = 0                       # 设置有效计数
        for i in range(gtbox.shape[0]): # 遍历边框坐标
            if gtbox[i][0] == 0 and gtbox[i][1] == 0 and gtbox[i][2] == 0 and gtbox[i][3] == 0:
                break
            count += 1                  # 增加有效计数
            
        # 设置边框类别
        gtbox = gtbox[:count] # 设置物体边框
        gtcls = gtcls[:count] # 设置物体类别
        
        return gtbox, gtcls
        
    def update(self, infer, gtbox, gtcls):
        """
        更新预测统计
        params: 
        - infer: 预测结果
        - gtbox: 物体边框
        - gtcls: 物体类别
        """
        # 统计各类数量
        for gtcls_item in gtcls:
            self.count[int(np.array(gtcls_item))] += 1 # 增加类别计数
        
        # 统计各类得分
        visited = [False] * len(gtcls)                 # 各类访问标识
        for infer_item in infer:
            # 获取预测数据
            pdcls, pdsco, xmin, ymin, xmax, ymax = infer_item.tolist() # 获取预测数据
            pdbox = [xmin, ymin, xmax, ymax]                           # 获取预测边框
            
            # 计算最大边框
            max_index = -1 # 设置最大交并索引
            max_iou = -1.0 # 设置最大交并比值
            for i, gtcls_item in enumerate(gtcls): # 遍历真实类别列表
                if int(gtcls_item) == int(pdcls):  # 如果真实类别等于预测类别，则计算交并比值
                    iou = self.get_iou_xyxy(pdbox, gtbox[i])
                    if iou > max_iou: # 若交并比值大于最大值，则更新索引和交并比
                        max_index = i # 设置最大索引数值
                        max_iou = iou # 设置最大交并比值
            
            # 统计各类得分
            if max_iou > self.bbox_thresh: # 如果最大交并比值大于验证边框阈值
                if not visited[max_index]: # 如果该物体没有被统计，则添加到列表，并设置访问标识为真
                    self.score[int(pdcls)].append([pdsco, 1.0]) # 添加各类正确正例
                    visited[max_index] = True                   # 设置访问标识为真
                else:                      # 如果该物体已经被统计，则添加到列表，并设置为成错误正例
                    self.score[int(pdcls)].append([pdsco, 0.0]) # 添加各类错误正例
            else:                          # 如果最大交并比不大于验证边框阈值，则添加到列表，并设置成错误正例
                self.score[int(pdcls)].append([pdsco, 0.0])     # 添加各类错误正例
        
    def get_iou_xyxy(self, box1, box2):
        """
        计算交并比值
        params:
        - box1: 物体边框1，ndarray类型
        - box2: 物体边框2，ndarray类型
        return:
        - iou : 交并比值
        """
        # 计算交集面积
        x_min = max(box1[0], box2[0]) # 边框1与边框2的x1坐标
        y_min = max(box1[1], box2[1]) # 边框1与边框2的y1坐标
        x_max = min(box1[2], box2[2]) # 边框1与边框2的x2坐标
        y_max = min(box1[3], box2[3]) # 边框1与边框2的y2坐标

        intersection = np.maximum(x_max - x_min, 0) * np.maximum(y_max - y_min, 0) # 边框1与边框2交集面积
        
        # 计算并集面积
        s1 = np.maximum(box1[2] - box1[0], 0) * np.maximum(box1[3] - box1[1], 0) # 边框1的面积
        s2 = np.maximum(box2[2] - box2[0], 0) * np.maximum(box2[3] - box2[1], 0) # 边框2的面积

        union = s1 + s2 - intersection + 1e-9 # 边框1与边框2并集面积
        
        # 计算交并比值
        iou = intersection / union
        
        return iou
    
    def mAP(self):
        """
        计算验证精度
        return:
        - mAP:各类平均精度
        """
        # 计算每类精度
        mAP = 0                                           # 各类平均精度
        cnt = 0                                           # 各类类别计数
        for score, count in zip(self.score, self.count):  # 遍历每类物体
            # 统计正误正例
            if count == 0:      # 如果该类数量为零，则继续下一个类别
                continue
            if len(score) == 0: # 如果得分列表为空，则继续下一个类别
                cnt += 1                                  # 增加类别计数
                continue
            tp_list, fp_list = self.get_tp_fp_list(score) # 统计正误正例
            
            # 计算预测准确率和召回率
            precision = []                              # 准确率列表
            recall = []                                 # 召回率列表
            for tp, fp in zip(tp_list, fp_list):
                precision.append(float(tp) / (tp + fp)) # 添加准确率
                recall.append(float(tp) / count)        # 添加召回率
            
            # 计算平均精度
            AP = 0.0                                           # 设置平均精度
            pre_recall = 0.0                                   # 设置前召回率
            for i in range(len(precision)):                    # 遍历正确率表
                recall_gap = math.fabs(recall[i] - pre_recall) # 计算召回差值
                if recall_gap > 1e-6:                          # 若召回率改变，计算平均精度，更新前召回率
                    AP += precision[i] * recall_gap            # 累加平均精度
                    pre_recall = recall[i]                     # 更新前召回率
            
            # 更新各类精度
            mAP += AP # 累加各类精度
            cnt += 1  # 增加类别计数
            
        # 计算平均精度
        mAP = (mAP / float(cnt)) if cnt > 0 else mAP
        
        return mAP

    def get_tp_fp_list(self, score):
        """
        对得分列表进行从大到小排序，按排序统计正确正例和错误正例数量
        params:
        - score  : 得分列表
        return:
        - tp_list: 正确正例列表
        - fp_list: 错误正例列表
        """
        # 得分排序
        tp = 0                                                       # 正确正例数量
        fp = 0                                                       # 错误正例数量
        tp_list = []                                                 # 正确正例列表
        fp_list = []                                                 # 错误正例列表
        score_list = sorted(score, key=lambda s: s[0], reverse=True) # 从大到小排序
        
        # 统计列表
        for (score, label) in score_list: # 遍历得分列表
            tp += int(label)              # 统计正确正例
            tp_list.append(tp)            # 添加正确正例
            fp += 1 - int(label)          # 统计错误正例
            fp_list.append(fp)            # 添加错误正例
        
        return tp_list, fp_list

####################################################################################

if __name__ == "__main__":
    # 实例化训练器
    trainer = Trainer(
        train_batch=2,                        # 训练批次大小
        valid_batch=2,                        # 验证批次大小
        epoch_num  =600,                      # 训练轮次总数
        eval_save  =5,                        # 验证保存频率，每隔多少轮验证并保存一次
        iter_show  =3,                        # 迭代显示次数，每个轮次中迭代显示的次数
        lin_epoch  =10,                       # 线性预热轮数
        lin_value  =0.0,                      # 线性始学习率
        lrate_max  =0.01,                     # 最大学习率值
        cos_value  =0.0,                      # 余弦止学习率，当分段衰减为空时使用余弦
        pie_epoch  =[360, 480],               # 分段衰减轮数
        pie_value  =[0.01, 0.001, 0.0001],    # 分段学习率值
        optimizer  ='AdamW',                  # 权重值优化器，必须为AdamW或Momentum
        l2_decays  =0.0005,                   # 权重衰减系数
        num_classes=4,                        # 物体类别数量
        bbox_thresh=0.50,                     # 验证边框阈值
        load_flag  ='none',                   # 加载断点标识，none,resume,finetune
        train_txt  ='./data/train.txt',       # 训练数据路径
        valid_txt  ='./data/valid.txt',       # 验证数据路径
        label_txt  ='./data/label.txt',       # 标签文件路径
        save_path  ='./out/',                 # 模型保存路径
        logs_path  ='./log/'                  # 日志保存路径
    )
    
    # 启动训练数据
    trainer.train()